#!/usr/bin/env python3
"""
Phase 3: Heterogeneity-Robust Staggered Difference-in-Differences

Core identification strategy for the causal paper. Tests whether capital
account opening activates the demographic-CA channel.

Key improvement over Phase 4h (followup): uses modern DiD estimators that
handle heterogeneous treatment effects across cohorts, rather than standard
TWFE which can produce biased estimates.

Models:
PART A — TWFE Baseline & Diagnostics
  1. Standard TWFE event study (replicate Phase 4h, extend to all openers)
  2. De Chaisemartin-d'Haultfoeuille negative weight diagnostic
  3. Triple-difference: Z × post_opening (does opening amplify demographics?)

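PART B — Heterogeneity-Robust Imputation Estimator
  4. Imputation estimator (Borusyak-Jaravel-Spiess 2024 style)
  5. Aggregated event-study coefficients

PART C — Extended Triple-Difference
  6. Demeaned Z × post, continuous Z × KAOPEN, pre/post split samples,
     and delayed-activation bins

PART D — Cohort-Specific ATTs
  7. Cohort-specific ATTs (Callaway-Sant'Anna style)

Scope Variations (samples used within the parts above)
  8. CCA-only subsample
  9. All openers (global, 105 countries)
  10. Transition economies only (CCA + CEE + Baltics)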

Output:
  did_twfe_results.csv — TWFE coefficients and diagnostics
  did_event_study.csv — Event-time coefficients
  did_cohort_atts.csv — Cohort-specific ATTs
  did_triple_diff.csv — Triple-difference results
  did_imputation.csv — Imputation estimator results
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats as scipy_stats
import statsmodels.api as sm

# Filesystem layout: every path hangs off a hard-coded project root.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
MULTILATERAL_DIR = PROJECT_DIR / "multilateral"
CAUSAL_DIR = PROJECT_DIR / "causal_identification"
PROCESSED_DIR = CAUSAL_DIR / "data" / "processed"
OUTPUT_DIR = CAUSAL_DIR / "output" / "tables"
# Side effect at import time: create the output directory if it is missing.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# MULTILATERAL_DIR is pushed to the front of sys.path before importing `src`
# (presumably so that `src.model` resolves inside that directory — confirm).
sys.path.insert(0, str(MULTILATERAL_DIR))
from src.model import PanelGLS

# Country-code groups (matched against the panel's 'iso3' column) used to
# build the regional subsamples.
CCA_COUNTRIES = [
    'ARM', 'AZE', 'BLR', 'GEO', 'KAZ', 'KGZ', 'MDA',
    'MNG', 'RUS', 'TJK', 'TKM', 'UKR', 'UZB'
]
BALTIC_COUNTRIES = ['EST', 'LVA', 'LTU']
CEE_COUNTRIES = [
    'ALB', 'BGR', 'BIH', 'HRV', 'CZE', 'HUN', 'MKD', 'MNE',
    'POL', 'ROU', 'SRB', 'SVK', 'SVN'
]

# Regressor column names: the demographic variables (also interacted with the
# treatment indicators below) and the control variables.
DEMO_VARS = ['Z_1', 'Z_2', 'Z_3']
CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'trade_openness', 'log_rel_opw']


def run_gls(df, dep_var, indep_vars, label=""):
    """Fit PanelGLS on the complete cases of *df* and summarise the fit.

    Args:
        df: Panel with 'iso3' and 'year' columns plus the model variables.
        dep_var: Name of the dependent-variable column.
        indep_vars: Candidate regressor names; any that are not columns of
            *df* are silently left out of the model.
        label: Free-text tag stored in the result and used in messages.

    Returns:
        ``None`` when the complete-case sample has fewer than 3 countries or
        fewer than 30 rows. Otherwise a dict with 'label', 'r_squared',
        'n_obs', 'n_countries', 'rho' and, for every regressor used,
        '<name>_coef', '<name>_se' and '<name>_pval'.
    """
    regressors = [name for name in indep_vars if name in df.columns]
    sample = df.dropna(subset=[dep_var, *regressors]).copy()

    # Guard: too few countries or rows for the fit to be meaningful.
    too_small = sample['iso3'].nunique() < 3 or len(sample) < 30
    if too_small:
        print(f"  {label}: insufficient obs ({len(sample)})")
        return None

    outcome = sample[dep_var].values
    design = sample[regressors].values
    model = PanelGLS()
    model.fit(outcome, design, sample['iso3'].values, sample['year'].values)

    summary = {
        'label': label,
        'r_squared': model.r_squared,
        'n_obs': model.n_obs,
        'n_countries': model.n_countries,
        'rho': model.rho,
    }
    # Per-regressor estimates, keyed by the regressor's column name.
    for pos, name in enumerate(regressors):
        summary.update({
            f'{name}_coef': model.beta[pos],
            f'{name}_se': model.se[pos],
            f'{name}_pval': model.pvalues[pos],
        })
    return summary


def load_panel():
    """Read causal_panel.csv from PROCESSED_DIR and keep years 1992-2024.

    Returns:
        A copy of the rows whose 'year' lies in the closed interval
        [1992, 2024].
    """
    panel = pd.read_csv(PROCESSED_DIR / "causal_panel.csv", low_memory=False)
    in_window = (panel['year'] >= 1992) & (panel['year'] <= 2024)
    return panel.loc[in_window].copy()


# =====================================================================
# PART A: TWFE BASELINE & DIAGNOSTICS
# =====================================================================

def part_a_twfe(df):
    """Run the Part A regressions and diagnostics, printing results as it goes.

    Steps, in order:
      A1  run_gls of ca_gdp on DEMO_VARS + CONTROLS + post, using openers and
          never-opened countries.
      A2  The same specification with Z × post interactions, on three
          samples: openers + never-opened, the CCA countries, and the
          CCA + CEE + Baltic group.
      A3  Lead/lag event studies (run_twfe_event_study) on the
          openers + never-opened sample and on the CCA sample.
      A4  neg_weight_diagnostic on the openers + never-opened sample.

    Args:
        df: Country-year panel. This function reads 'status', 'post_opening',
            'iso3' and the DEMO_VARS columns; the helpers it calls read more.

    Returns:
        ``(results, event_results, event_results_cca)``: the list of run_gls
        result dicts for the fits that ran, followed by the two
        run_twfe_event_study return values.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("PART A: TWFE BASELINE & DIAGNOSTICS")
    print(rule)

    interaction_names = [f'{z}_x_post' for z in DEMO_VARS]
    triple_diff_vars = DEMO_VARS + CONTROLS + ['post'] + interaction_names

    def add_post_terms(frame):
        # Missing post_opening is coded as 0 (untreated) before interacting.
        frame['post'] = frame['post_opening'].fillna(0)
        for z, name in zip(DEMO_VARS, interaction_names):
            frame[name] = frame[z] * frame['post']

    def show_interactions(fit):
        for z in DEMO_VARS:
            slope = fit.get(f'{z}_x_post_coef', np.nan)
            slope_p = fit.get(f'{z}_x_post_pval', np.nan)
            print(f"  {z}×post: {slope:8.3f} (p={slope_p:.4f})")

    results = []

    # ----- A1: post-opening level effect -----
    print("\n--- A1: Standard TWFE with post_opening ---")

    # Openers are the treated group; never-opened countries are the controls.
    baseline = df[df['status'].isin(['opener', 'never_opened'])].copy()
    baseline['post'] = baseline['post_opening'].fillna(0)

    fit = run_gls(baseline, 'ca_gdp', DEMO_VARS + CONTROLS + ['post'],
                  "A1a: TWFE post effect (all openers)")
    if fit:
        results.append(fit)
        print(f"  R²={fit['r_squared']:.4f}, N={fit['n_obs']}, countries={fit['n_countries']}")
        print(f"  post_opening: {fit['post_coef']:.3f} (p={fit['post_pval']:.4f})")
        for z in DEMO_VARS:
            print(f"  {z}: {fit[f'{z}_coef']:.3f} (p={fit[f'{z}_pval']:.4f})")

    # ----- A2: triple-difference, the key test -----
    print("\n--- A2: Triple-Difference (Z × post_opening) ---")
    print("  Does opening amplify the demographic channel?")

    # A2a: all openers + never-opened. Work on a copy so the baseline frame
    # handed to the event study and the weight diagnostic stays unchanged.
    td = baseline.copy()
    add_post_terms(td)

    fit = run_gls(td, 'ca_gdp', triple_diff_vars, "A2a: Triple-diff (all openers)")
    if fit:
        results.append(fit)
        print(f"\n  R²={fit['r_squared']:.4f}, N={fit['n_obs']}")
        print(f"  post_opening: {fit.get('post_coef', np.nan):.3f} (p={fit.get('post_pval', np.nan):.4f})")
        for z in DEMO_VARS:
            base = fit.get(f'{z}_coef', np.nan)
            base_p = fit.get(f'{z}_pval', np.nan)
            slope = fit.get(f'{z}_x_post_coef', np.nan)
            slope_p = fit.get(f'{z}_x_post_pval', np.nan)
            print(f"  {z} (pre-opening):  {base:8.3f} (p={base_p:.4f})")
            print(f"  {z}×post:           {slope:8.3f} (p={slope_p:.4f})")
            # Implied post-opening slope, shown only when both parts exist.
            if not np.isnan(base) and not np.isnan(slope):
                print(f"  {z} (post-opening): {base + slope:8.3f}")

    # A2b: CCA countries only (always-open countries are kept in this sample).
    in_cca = df['iso3'].isin(CCA_COUNTRIES)
    status_ok = df['status'].isin(['opener', 'never_opened', 'always_open'])
    cca_openers = df[in_cca & status_ok].copy()
    add_post_terms(cca_openers)

    fit = run_gls(cca_openers, 'ca_gdp', triple_diff_vars, "A2b: Triple-diff (CCA only)")
    if fit:
        results.append(fit)
        print(f"\n  CCA Triple-Diff: R²={fit['r_squared']:.4f}, N={fit['n_obs']}")
        show_interactions(fit)

    # A2c: transition economies (CCA + CEE + Baltics), no status filter.
    trans_countries = CCA_COUNTRIES + CEE_COUNTRIES + BALTIC_COUNTRIES
    trans = df[df['iso3'].isin(trans_countries)].copy()
    add_post_terms(trans)

    fit = run_gls(trans, 'ca_gdp', triple_diff_vars, "A2c: Triple-diff (transition economies)")
    if fit:
        results.append(fit)
        print(f"\n  Transition Triple-Diff: R²={fit['r_squared']:.4f}, N={fit['n_obs']}")
        show_interactions(fit)

    # ----- A3: dynamic effects -----
    print("\n--- A3: TWFE Event Study (dynamic effects) ---")

    event_results = run_twfe_event_study(baseline, "all openers")
    event_results_cca = run_twfe_event_study(cca_openers, "CCA only")

    # ----- A4: negative-weight diagnostic -----
    print("\n--- A4: Negative Weight Diagnostic ---")
    neg_weight_diagnostic(baseline)

    return results, event_results, event_results_cca


def run_twfe_event_study(df, label, max_leads=5, max_lags=10):
    """
    Event study with lead/lag indicator regressors, fitted with PanelGLS.

    Indicators e_{-max_leads}, ..., e_{-2}, e_0, ..., e_{max_lags} are built
    from event time = year - opening_year. The two end points absorb every
    period at or beyond them, and e_{-1} (the year before opening) is the
    omitted reference. Rows with an opening_year carry the indicators;
    never-opened rows get zeros throughout; all other rows are dropped.

    Args:
        df: Panel with 'year', 'opening_year', 'status', 'iso3', 'ca_gdp'
            and whichever DEMO_VARS / CONTROLS columns are available.
        label: Sample name, printed and stored on every output row.
        max_leads: Number of pre-opening periods (binned at the far end).
        max_lags: Number of post-opening periods (binned at the far end).

    Returns:
        ``None`` if fewer than 50 complete rows remain. Otherwise a DataFrame
        sorted by event_time with columns event_time, coef, se, pval,
        ci_lower, ci_upper, label, including a zero row for event time -1.
    """
    print(f"\n  Event study: {label}")

    def stars(p):
        if p < 0.001:
            return '***'
        if p < 0.01:
            return '**'
        if p < 0.05:
            return '*'
        return ''

    def event_time_of(col):
        # Column names encode the period: 'e_m3' -> -3, 'e_4' -> 4.
        return -int(col[3:]) if col.startswith('e_m') else int(col[2:])

    panel = df.copy()

    # Countries with an opening year carry the event dummies; never-opened
    # countries act as controls with all dummies equal to zero.
    switchers = panel[panel['opening_year'].notna()].copy()
    controls = panel[panel['status'] == 'never_opened'].copy()

    switchers['event_time'] = switchers['year'] - switchers['opening_year']
    rel_time = switchers['event_time']

    event_cols = []
    for e in range(-max_leads, max_lags + 1):
        if e == -1:
            # Reference period, left out of the regression.
            continue
        col = f'e_{e}' if e >= 0 else f'e_m{abs(e)}'
        if e == -max_leads:
            flag = rel_time <= e      # far-lead bin
        elif e == max_lags:
            flag = rel_time >= e      # far-lag bin
        else:
            flag = rel_time == e
        switchers[col] = flag.astype(float)
        event_cols.append(col)

    for col in event_cols:
        controls[col] = 0.0

    es_full = pd.concat([switchers, controls], ignore_index=True)

    # Regression on the complete cases of the available regressors.
    candidates = DEMO_VARS + CONTROLS + event_cols
    regressors = [v for v in candidates if v in es_full.columns]
    sample = es_full.dropna(subset=['ca_gdp'] + regressors).copy()

    if len(sample) < 50:
        print(f"    Insufficient observations: {len(sample)}")
        return None

    outcome = sample['ca_gdp'].values
    design = sample[regressors].values
    model = PanelGLS()
    model.fit(outcome, design, sample['iso3'].values, sample['year'].values)

    # Collect the event-dummy estimates with normal-approximation 95% bands.
    rows = []
    for pos, name in enumerate(regressors):
        if not name.startswith('e_'):
            continue
        beta, se = model.beta[pos], model.se[pos]
        rows.append({
            'event_time': event_time_of(name),
            'coef': beta,
            'se': se,
            'pval': model.pvalues[pos],
            'ci_lower': beta - 1.96 * se,
            'ci_upper': beta + 1.96 * se,
            'label': label,
        })

    # The omitted reference period is reported as an exact zero.
    rows.append({
        'event_time': -1,
        'coef': 0.0, 'se': 0.0, 'pval': 1.0,
        'ci_lower': 0.0, 'ci_upper': 0.0,
        'label': label,
    })

    event_df = pd.DataFrame(rows).sort_values('event_time')

    print(f"    N={model.n_obs}, countries={model.n_countries}, R²={model.r_squared:.4f}")
    print(f"\n    Event-time coefficients (effect on CA/GDP):")
    print(f"    {'e':>4s} {'coef':>8s} {'SE':>8s} {'p':>7s}")
    print(f"    {'-'*30}")
    for rec in event_df.itertuples(index=False):
        print(f"    {int(rec.event_time):>4d} {rec.coef:>8.3f} {rec.se:>8.3f} "
              f"{rec.pval:>7.4f} {stars(rec.pval)}")

    # Pre-trend check: sum of squared t-ratios of the leads before -1,
    # referred to a chi-square (covariances between leads are not used).
    leads = event_df[event_df['event_time'] < -1]
    if not leads.empty:
        lead_coefs = leads['coef'].values
        lead_ses = np.maximum(leads['se'].values, 1e-10)  # avoid divide-by-zero
        wald = np.sum((lead_coefs / lead_ses) ** 2)
        wald_p = 1 - scipy_stats.chi2.cdf(wald, len(lead_coefs))
        print(f"\n    Pre-trend test (joint H0: all pre-treatment = 0):")
        print(f"    Wald χ²={wald:.2f}, df={len(lead_coefs)}, p={wald_p:.4f}")

    return event_df


def neg_weight_diagnostic(df):
    """
    Negative-weight diagnostic in the spirit of De Chaisemartin-d'Haultfoeuille
    (2020).

    For every opening cohort g and calendar year t in which the cohort has
    rows, the function computes

        w_{g,t} = (D_gt - mean_g(post)) * (N_gt / N_t)

    where D_gt is 1 if t >= the cohort's opening year, mean_g(post) is the
    cohort's average of the 'post' column, N_gt is the cohort's row count in
    year t and N_t is the row count of all openers in year t. It then prints
    the share of absolute weight that is negative and a per-cohort summary.

    NOTE(review): the weights are demeaned by the cohort mean only; the dCDH
    weights are built from treatment residualised on both group and period
    effects — treat this as a heuristic and confirm before citing it.

    Args:
        df: Panel with 'opening_year', 'iso3', 'year' and a 'post' column
            (callers create 'post' before calling). Rows without an
            opening_year are ignored.

    Returns:
        None. Results are printed only.
    """
    openers = df[df['opening_year'].notna()].copy()
    if len(openers) == 0:
        print("  No openers found for negative weight diagnostic")
        return

    # One opening year per country defines its cohort.
    cohorts = openers.groupby('iso3')['opening_year'].first()
    unique_cohorts = sorted(cohorts.unique())

    # Row count of all openers per calendar year, computed once rather than
    # re-filtering the frame for every cohort/year pair.
    n_total_by_year = openers.groupby('year').size()

    weights = []
    for g_year in unique_cohorts:
        g_countries = cohorts[cohorts == g_year].index
        g_data = openers[openers['iso3'].isin(g_countries)]

        # Cohort-level quantities do not vary with t, so compute them once.
        d_g_mean = g_data['post'].mean()
        n_by_year = g_data.groupby('year').size()  # ascending in year

        for t, n_gt in n_by_year.items():
            d_gt = 1.0 if t >= g_year else 0.0
            weights.append({
                'cohort': g_year,
                'year': t,
                'weight': (d_gt - d_g_mean) * (n_gt / n_total_by_year.loc[t]),
                'n_countries': len(g_countries),
                'post': d_gt,
            })

    if not weights:
        print("  Could not compute weights")
        return

    w_df = pd.DataFrame(weights)
    total_pos = w_df[w_df['weight'] > 0]['weight'].sum()
    total_neg = w_df[w_df['weight'] < 0]['weight'].abs().sum()
    total = total_pos + total_neg

    neg_share = total_neg / total if total > 0 else 0

    print(f"  Negative weight share: {neg_share:.4f} ({100*neg_share:.1f}%)")
    print(f"  Positive weight sum: {total_pos:.4f}")
    print(f"  Negative weight sum: {total_neg:.4f}")

    if neg_share > 0.1:
        print("  WARNING: >10% negative weights — TWFE may be unreliable")
        print("  Interpretation: standard TWFE coefficient could have opposite sign")
        print("  to individual cohort ATTs. Use robust estimators (Part B).")
    else:
        print("  Negative weight share is small — TWFE likely reliable")

    # Per-cohort breakdown.
    print(f"\n  Weights by opening cohort:")
    for g in unique_cohorts:
        gw = w_df[w_df['cohort'] == g]
        n_c = gw['n_countries'].iloc[0] if len(gw) > 0 else 0
        w_sum = gw['weight'].sum()
        n_neg = (gw['weight'] < 0).sum()
        print(f"    Cohort {int(g)}: {n_c} countries, weight sum={w_sum:+.4f}, "
              f"neg periods={n_neg}/{len(gw)}")


# =====================================================================
# PART B: HETEROGENEITY-ROBUST ESTIMATORS
# =====================================================================

def part_b_robust(df):
    """
    Run the imputation estimator (run_imputation_estimator) on three samples.

    The samples are, in order: openers plus never-opened countries, the CCA
    countries, and the CCA + CEE + Baltic group. The per-sample event-time
    ATT records are pooled into one list; each record carries its sample name
    in 'label'.

    Args:
        df: Country-year panel with 'status' and 'iso3' columns plus whatever
            run_imputation_estimator reads.

    Returns:
        A list of event-time ATT dicts across all samples that produced
        results (empty if none did).
    """
    rule = "=" * 70
    print("\n" + rule)
    print("PART B: IMPUTATION ESTIMATOR (BJS-style)")
    print(rule)

    transition_iso = CCA_COUNTRIES + CEE_COUNTRIES + BALTIC_COUNTRIES
    samples = [
        ("All openers + never-opened", df[df['status'].isin(['opener', 'never_opened'])]),
        ("CCA only", df[df['iso3'].isin(CCA_COUNTRIES)]),
        ("Transition economies", df[df['iso3'].isin(transition_iso)]),
    ]

    pooled = []
    for name, subset in samples:
        print(f"\n--- Imputation: {name} ---")
        atts = run_imputation_estimator(subset, name)
        if atts is None:
            # Sample too small or no event-time cells; nothing to pool.
            continue
        pooled.extend(atts)

    return pooled


def run_imputation_estimator(df, label, max_event=15):
    """
    BJS-style imputation estimator.

    Step 1: On the "clean control" sample (never-opened rows plus opener rows
            with treated == 0), fit by OLS
            CA_it = α_i + γ_t + X_it'δ + ε_it
            with country and year dummies.
    Step 2: Predict the untreated outcome Ŷ_it(0) for each treated row.
    Step 3: τ_it = Y_it - Ŷ_it(0)
    Step 4: Average τ_it by event time (cells with fewer than 2 rows are
            skipped).

    Treated rows are excluded from steps 2-4 when
      * they have no opening_year (no event time can be assigned), or
      * their country or their year never appears in the fitted control
        sample: the matching dummy is then all-zero in the fit, its
        coefficient is not identified, and the prediction would silently use
        the reference category's effect instead.

    Standard errors are the plain standard error of the mean of τ_it within
    each event-time cell; they do not account for first-stage estimation
    error or within-country dependence.

    NOTE(review): only rows with post_opening == 1 are imputed, so the
    negative event times in the aggregation loop are presumably never
    populated and the "pre-treatment ATT" line will not print — confirm
    against the data.

    Args:
        df: Panel with 'post_opening', 'opening_year', 'year', 'status',
            'iso3', 'ca_gdp' and any available CONTROLS columns. Not mutated.
        label: Sample name, printed and stored on every output record.
        max_event: Largest event time to report.

    Returns:
        ``None`` if there is too little data at any stage or no event-time
        cell qualifies; otherwise a list of dicts with keys event_time, att,
        se, pval, ci_lower, ci_upper, n_obs, n_countries, label.
    """
    sdf = df.copy()

    # NaN post_opening compares unequal to 1, so those rows are untreated.
    sdf['treated'] = (sdf['post_opening'] == 1).astype(int)

    # Event time; NaN wherever opening_year is missing.
    sdf['event_time'] = sdf['year'] - sdf['opening_year']

    # Clean control sample: never-treated + not-yet-treated.
    clean_control = sdf[
        (sdf['status'] == 'never_opened') |
        ((sdf['status'] == 'opener') & (sdf['treated'] == 0))
    ].copy()

    # Treated sample. Rows without an event time cannot be aggregated and
    # would make the integer cast below fail, so drop them up front.
    treated = sdf[sdf['treated'] == 1].copy()
    no_event_time = treated['event_time'].isna()
    if no_event_time.any():
        print(f"  Dropping {int(no_event_time.sum())} treated obs without opening_year")
        treated = treated[~no_event_time].copy()

    if len(clean_control) < 30 or len(treated) < 10:
        print(f"  Insufficient data: {len(clean_control)} control, {len(treated)} treated")
        return None

    print(f"  Clean control: {len(clean_control)} obs, {clean_control['iso3'].nunique()} countries")
    print(f"  Treated:       {len(treated)} obs, {treated['iso3'].nunique()} countries")

    # Step 1: fixed-effects model on the clean control sample.
    control_vars = [v for v in CONTROLS if v in clean_control.columns]

    # Build the dummies on the stacked frame so control and treated rows share
    # one set of columns (and one reference category).
    all_obs = pd.concat([clean_control, treated], ignore_index=True)

    country_dummies = pd.get_dummies(all_obs['iso3'], prefix='fe', drop_first=True).astype(float)
    year_dummies = pd.get_dummies(all_obs['year'].astype(int), prefix='yr', drop_first=True).astype(float)

    fe_cols = list(country_dummies.columns)
    yr_cols = list(year_dummies.columns)

    all_obs = pd.concat([all_obs, country_dummies, year_dummies], axis=1)

    # Split back: control rows were stacked first.
    n_control = len(clean_control)
    clean_control_full = all_obs.iloc[:n_control]
    treated_full = all_obs.iloc[n_control:]

    reg_vars = control_vars + fe_cols + yr_cols

    comp_control = clean_control_full.dropna(subset=['ca_gdp'] + reg_vars)
    if len(comp_control) < 30:
        print(f"  Insufficient complete control obs: {len(comp_control)}")
        return None

    y_control = comp_control['ca_gdp'].values.astype(float)
    # has_constant='add' forces an intercept column; the default ('skip')
    # omits it when a regressor happens to be a non-zero constant, which
    # could leave the control and treated matrices with different widths.
    X_control = sm.add_constant(
        comp_control[reg_vars].values.astype(float), has_constant='add'
    )

    # OLS for imputation (fixed effects enter as dummies).
    ols = sm.OLS(y_control, X_control).fit()

    # Step 2: counterfactual for treated rows whose effects are identified.
    comp_treated = treated_full.dropna(subset=['ca_gdp'] + reg_vars)
    identified = (
        comp_treated['iso3'].isin(comp_control['iso3'].unique()) &
        comp_treated['year'].isin(comp_control['year'].unique())
    )
    n_unidentified = int((~identified).sum())
    if n_unidentified:
        print(f"  Dropping {n_unidentified} treated obs whose country or year "
              f"has no control observations")
        comp_treated = comp_treated[identified]

    if len(comp_treated) < 5:
        print(f"  Insufficient complete treated obs: {len(comp_treated)}")
        return None

    X_treated = sm.add_constant(
        comp_treated[reg_vars].values.astype(float), has_constant='add'
    )
    y_hat_0 = X_treated @ ols.params  # counterfactual

    # Step 3: individual treatment effects.
    comp_treated = comp_treated.copy()
    comp_treated['tau'] = comp_treated['ca_gdp'].values - y_hat_0
    comp_treated['event_time_int'] = comp_treated['event_time'].astype(int)

    # Step 4: aggregate by event time.
    event_atts = []
    for e in range(-3, max_event + 1):
        e_obs = comp_treated[comp_treated['event_time_int'] == e]
        n_e = len(e_obs)
        if n_e < 2:
            continue

        att = e_obs['tau'].mean()
        se = e_obs['tau'].std() / np.sqrt(n_e)
        t_stat = att / se if se > 0 else 0
        # Survival function keeps precision for very small tail areas.
        p_val = 2 * scipy_stats.t.sf(abs(t_stat), n_e - 1)

        event_atts.append({
            'event_time': e,
            'att': att,
            'se': se,
            'pval': p_val,
            'ci_lower': att - 1.96 * se,
            'ci_upper': att + 1.96 * se,
            'n_obs': n_e,
            'n_countries': e_obs['iso3'].nunique(),
            'label': label,
        })

    if not event_atts:
        print("  No event-time ATTs computed")
        return None

    att_df = pd.DataFrame(event_atts).sort_values('event_time')

    # Print results
    print(f"\n  Imputation ATTs by event time:")
    print(f"  {'e':>4s} {'ATT':>8s} {'SE':>8s} {'p':>7s} {'N':>5s} {'ctry':>5s}")
    print(f"  {'-'*40}")

    for _, row in att_df.iterrows():
        sig = '***' if row['pval'] < 0.001 else ('**' if row['pval'] < 0.01 else ('*' if row['pval'] < 0.05 else ''))
        print(f"  {int(row['event_time']):>4d} {row['att']:>8.3f} {row['se']:>8.3f} "
              f"{row['pval']:>7.4f} {int(row['n_obs']):>5d} {int(row['n_countries']):>5d} {sig}")

    # Overall ATT: unweighted mean of the post-treatment event-time ATTs.
    post_atts = att_df[att_df['event_time'] >= 0]
    if len(post_atts) > 0:
        overall_att = post_atts['att'].mean()
        # SE of a mean of cell means, treating the cells as independent.
        overall_se = np.sqrt(np.sum(post_atts['se']**2)) / len(post_atts)
        overall_t = overall_att / overall_se if overall_se > 0 else 0
        overall_p = 2 * scipy_stats.t.sf(abs(overall_t), post_atts['n_obs'].sum() - 1)
        print(f"\n  Overall ATT (avg post-treatment): {overall_att:.3f} "
              f"(SE={overall_se:.3f}, p={overall_p:.4f})")

    # Pre-treatment ATTs (should be ~0 if parallel trends hold).
    pre_atts = att_df[att_df['event_time'] < 0]
    if len(pre_atts) > 0:
        pre_att = pre_atts['att'].mean()
        pre_se = np.sqrt(np.sum(pre_atts['se']**2)) / len(pre_atts)
        print(f"  Pre-treatment ATT (parallel trends check): {pre_att:.3f} "
              f"(SE={pre_se:.3f})")

    return event_atts


# =====================================================================
# PART C: TRIPLE-DIFFERENCE WITH DEMOGRAPHICS
# =====================================================================

def part_c_triple_diff_extended(df):
    """
    Extended triple-difference specifications.

    The key question: does capital account opening amplify the demographic
    effect on current accounts?

    CA_it = α_i + γ_t + β₁·Z_it + β₂·Post_it + β₃·Z_it×Post_it + X'δ + ε_it

    β₃ is the parameter of interest.

    Specifications, all fitted through run_gls:
      C1a  Country-demeaned variables with Z_dm × post interactions
           (openers + never-opened).
      C2   Z × kaopen interactions on the full panel; C2b drops CCA countries.
      C3   Separate fits on pre-opening and post-opening rows of openers.
      C4   Post-opening period bins (0-4, 5-9, 10-14, 15+ years) for openers,
           with never-opened countries as controls.
      C5   C4 plus Z × period-bin interactions (only Z_1 terms are printed).

    Args:
        df: Country-year panel with 'status', 'post_opening', 'iso3',
            'kaopen', 'ca_gdp' and the DEMO_VARS columns. An 'event_time'
            column is used for C4/C5 when present; otherwise it is derived as
            year - opening_year, the definition used by the other estimators
            in this file.

    Returns:
        List of run_gls result dicts for the specifications that ran.
    """
    print("\n" + "=" * 70)
    print("PART C: EXTENDED TRIPLE-DIFFERENCE SPECIFICATIONS")
    print("=" * 70)

    def stars(p):
        # NaN compares False everywhere and therefore yields no stars.
        if p < 0.001:
            return '***'
        if p < 0.01:
            return '**'
        if p < 0.05:
            return '*'
        return ''

    results = []

    # C1: Full sample with country FE (via demeaning)
    print("\n--- C1: Triple-diff with within-country variation ---")

    openers_controls = df[df['status'].isin(['opener', 'never_opened'])].copy()
    openers_controls['post'] = openers_controls['post_opening'].fillna(0)

    # Demean by country (pseudo-FE). Means use each variable's own non-missing
    # rows, not the final estimation sample.
    for v in DEMO_VARS + CONTROLS + ['ca_gdp']:
        if v in openers_controls.columns:
            openers_controls[f'{v}_dm'] = openers_controls.groupby('iso3')[v].transform(
                lambda x: x - x.mean()
            )

    for z in DEMO_VARS:
        openers_controls[f'{z}_x_post'] = openers_controls[z] * openers_controls['post']
        openers_controls[f'{z}_dm_x_post'] = openers_controls[f'{z}_dm'] * openers_controls['post']

    # C1a: Demeaned triple-diff
    dm_vars = (
        [f'{z}_dm' for z in DEMO_VARS]
        + [f'{v}_dm' for v in CONTROLS if f'{v}_dm' in openers_controls.columns]
        + ['post']
        + [f'{z}_dm_x_post' for z in DEMO_VARS]
    )

    r = run_gls(openers_controls, 'ca_gdp_dm', dm_vars, "C1a: Demeaned triple-diff")
    if r:
        results.append(r)
        print(f"  R²={r['r_squared']:.4f}, N={r['n_obs']}")
        for z in DEMO_VARS:
            inter = r.get(f'{z}_dm_x_post_coef', np.nan)
            inter_p = r.get(f'{z}_dm_x_post_pval', np.nan)
            print(f"  {z}_dm × post: {inter:8.3f} (p={inter_p:.4f})")

    # C2: Interacted with KAOPEN level (continuous treatment intensity)
    print("\n--- C2: Continuous treatment (KAOPEN level as intensity) ---")

    ct = df.copy()
    for z in DEMO_VARS:
        ct[f'{z}_x_kaopen'] = ct[z] * ct['kaopen']

    vars_c2 = DEMO_VARS + CONTROLS + ['kaopen'] + [f'{z}_x_kaopen' for z in DEMO_VARS]

    def show_kaopen(fit):
        for z in DEMO_VARS:
            inter = fit.get(f'{z}_x_kaopen_coef', np.nan)
            inter_p = fit.get(f'{z}_x_kaopen_pval', np.nan)
            print(f"  {z} × KAOPEN: {inter:8.3f} (p={inter_p:.4f})")

    r = run_gls(ct, 'ca_gdp', vars_c2, "C2: Continuous Z×KAOPEN (full sample)")
    if r:
        results.append(r)
        print(f"  R²={r['r_squared']:.4f}, N={r['n_obs']}")
        show_kaopen(r)

    # C2b: Same but ex-CCA
    ct_no_cca = ct[~ct['iso3'].isin(CCA_COUNTRIES)]
    r = run_gls(ct_no_cca, 'ca_gdp', vars_c2, "C2b: Z×KAOPEN (ex-CCA)")
    if r:
        results.append(r)
        print(f"\n  Ex-CCA: R²={r['r_squared']:.4f}, N={r['n_obs']}")
        show_kaopen(r)

    # C3: Pre vs Post split samples
    print("\n--- C3: Split sample (pre vs post opening) ---")

    openers_only = df[df['status'] == 'opener'].copy()
    openers_only['post'] = openers_only['post_opening'].fillna(0)

    pre = openers_only[openers_only['post'] == 0]
    post = openers_only[openers_only['post'] == 1]

    base_vars = DEMO_VARS + CONTROLS
    r_pre = run_gls(pre, 'ca_gdp', base_vars, "C3a: Pre-opening only")
    if r_pre:
        results.append(r_pre)
        print(f"\n  Pre-opening: R²={r_pre['r_squared']:.4f}, N={r_pre['n_obs']}")
        for z in DEMO_VARS:
            print(f"    {z}: {r_pre.get(f'{z}_coef', np.nan):.3f} "
                  f"(p={r_pre.get(f'{z}_pval', np.nan):.4f})")

    r_post = run_gls(post, 'ca_gdp', base_vars, "C3b: Post-opening only")
    if r_post:
        results.append(r_post)
        print(f"\n  Post-opening: R²={r_post['r_squared']:.4f}, N={r_post['n_obs']}")
        for z in DEMO_VARS:
            print(f"    {z}: {r_post.get(f'{z}_coef', np.nan):.3f} "
                  f"(p={r_post.get(f'{z}_pval', np.nan):.4f})")

    # C4: Delayed activation — does effect strengthen over time?
    print("\n--- C4: Delayed activation test ---")

    # Use the panel's event_time when it exists; otherwise derive it so this
    # step does not depend on a precomputed column.
    if 'event_time' in df.columns:
        timed = df
    else:
        timed = df.assign(event_time=df['year'] - df['opening_year'])

    openers_with_event = timed[
        (timed['status'] == 'opener') & (timed['event_time'].notna())
    ].copy()

    # Bins: 0-4, 5-9, 10-14 and 15+ years after opening.
    period_cols = ['post_0_4', 'post_5_9', 'post_10_14', 'post_15plus']
    elapsed = openers_with_event['event_time']
    for col, lo, hi in [('post_0_4', 0, 4), ('post_5_9', 5, 9), ('post_10_14', 10, 14)]:
        openers_with_event[col] = ((elapsed >= lo) & (elapsed <= hi)).astype(float)
    openers_with_event['post_15plus'] = (elapsed >= 15).astype(float)

    # Never-opened countries are the controls: every bin is zero.
    never = timed[timed['status'] == 'never_opened'].copy()
    for col in period_cols:
        never[col] = 0.0

    delayed_df = pd.concat([openers_with_event, never], ignore_index=True)

    vars_c4 = DEMO_VARS + CONTROLS + period_cols
    r = run_gls(delayed_df, 'ca_gdp', vars_c4, "C4: Delayed activation")
    if r:
        results.append(r)
        print(f"\n  Delayed activation: R²={r['r_squared']:.4f}, N={r['n_obs']}")
        for col in period_cols:
            coef = r.get(f'{col}_coef', np.nan)
            pval = r.get(f'{col}_pval', np.nan)
            sig = stars(pval)
            print(f"    {col}: {coef:+.3f} (p={pval:.4f}) {sig}")

    # C5: Delayed activation with Z interactions
    print("\n--- C5: Delayed activation × demographics ---")

    for z in DEMO_VARS:
        for col in period_cols:
            delayed_df[f'{z}_x_{col}'] = delayed_df[z] * delayed_df[col]

    z_period_interactions = [f'{z}_x_{col}' for z in DEMO_VARS for col in period_cols]
    vars_c5 = DEMO_VARS + CONTROLS + period_cols + z_period_interactions
    r = run_gls(delayed_df, 'ca_gdp', vars_c5, "C5: Delayed activation × Z")
    if r:
        results.append(r)
        print(f"\n  R²={r['r_squared']:.4f}, N={r['n_obs']}")
        # Only the Z_1 interactions are reported here.
        for col in period_cols:
            z1_inter = r.get(f'Z_1_x_{col}_coef', np.nan)
            z1_p = r.get(f'Z_1_x_{col}_pval', np.nan)
            sig = stars(z1_p)
            print(f"    Z₁ × {col}: {z1_inter:+8.3f} (p={z1_p:.4f}) {sig}")

    return results


# =====================================================================
# PART D: COHORT-SPECIFIC ATTs (Callaway-Sant'Anna style)
# =====================================================================

def part_d_cohort_atts(df):
    """
    Cohort-by-cohort difference-in-differences against never-opened countries.

    For each opening year g listed for 'opener' countries in
    treatment_cohorts.csv, the function computes

        DiD_g = (mean post - mean pre of ca_gdp, cohort g)
              - (mean post - mean pre of ca_gdp, never-opened)

    with "post" meaning year >= g. This is a pooled 2×2 comparison of means
    per cohort — simpler than the Callaway & Sant'Anna (2021) ATT(g,t)
    estimator it is modelled on. The standard error adds the four
    variance-of-mean terms, i.e. it treats all observations as independent.

    Missing ca_gdp values are dropped before counting, so the sample sizes
    used in the standard error, the degrees of freedom and the reported
    n_pre / n_post match the observations that enter the means.

    Side effects: prints a table and, when at least one cohort qualifies,
    writes did_cohort_atts.csv to OUTPUT_DIR.

    Args:
        df: Country-year panel with 'status', 'iso3', 'year', 'ca_gdp'.

    Returns:
        List of per-cohort result dicts (empty if no cohort has at least 3
        pre and 3 post cohort observations and 10 pre and 10 post control
        observations).
    """
    print("\n" + "=" * 70)
    print("PART D: COHORT-SPECIFIC ATTs")
    print("=" * 70)

    never_opened = df[df['status'] == 'never_opened']

    # Cohort membership comes from the cohort file, not from the panel.
    cohorts_df = pd.read_csv(PROCESSED_DIR / "treatment_cohorts.csv")
    openers_info = cohorts_df[cohorts_df['status'] == 'opener'][['iso3', 'opening_year']]

    unique_years = sorted(openers_info['opening_year'].dropna().unique())
    print(f"  Opening cohorts: {[int(y) for y in unique_years]}")

    # For each cohort, compute the ATT using the never-treated as control.
    cohort_results = []

    for g_year in unique_years:
        g_countries = openers_info[openers_info['opening_year'] == g_year]['iso3'].tolist()
        if not g_countries:
            continue

        # Cohort outcomes before / from the opening year, non-missing only.
        g_data = df[df['iso3'].isin(g_countries)]
        g_pre = g_data.loc[g_data['year'] < g_year, 'ca_gdp'].dropna()
        g_post = g_data.loc[g_data['year'] >= g_year, 'ca_gdp'].dropna()

        # Control group over the same two windows.
        c_pre = never_opened.loc[never_opened['year'] < g_year, 'ca_gdp'].dropna()
        c_post = never_opened.loc[never_opened['year'] >= g_year, 'ca_gdp'].dropna()

        if len(g_pre) < 3 or len(g_post) < 3 or len(c_pre) < 10 or len(c_post) < 10:
            continue

        # Simple DiD: (g_post - g_pre) - (c_post - c_pre)
        did = (g_post.mean() - g_pre.mean()) - (c_post.mean() - c_pre.mean())

        # SE: sum of the four variance-of-mean terms.
        se = np.sqrt(
            g_post.var() / len(g_post) + g_pre.var() / len(g_pre) +
            c_post.var() / len(c_post) + c_pre.var() / len(c_pre)
        )
        t_stat = did / se if se > 0 else 0
        dof = len(g_post) + len(g_pre) + len(c_post) + len(c_pre) - 4
        p_val = 2 * scipy_stats.t.sf(abs(t_stat), dof)

        # Flag cohorts that contain at least one CCA country.
        has_cca = any(c in CCA_COUNTRIES for c in g_countries)

        cohort_results.append({
            'cohort_year': int(g_year),
            'n_countries': len(g_countries),
            'countries': ', '.join(sorted(g_countries)),
            'has_cca': has_cca,
            'g_pre_mean': g_pre.mean(),
            'g_post_mean': g_post.mean(),
            'c_pre_mean': c_pre.mean(),
            'c_post_mean': c_post.mean(),
            'did': did,
            'se': se,
            'pval': p_val,
            'n_pre': len(g_pre),
            'n_post': len(g_post),
        })

    if cohort_results:
        cohort_df = pd.DataFrame(cohort_results)

        print(f"\n  Cohort-specific ATTs (DiD vs never-treated):")
        print(f"  {'Cohort':>7s} {'N ctry':>7s} {'DiD':>8s} {'SE':>8s} {'p':>7s} {'CCA?':>5s}  Countries")
        print(f"  {'-'*80}")

        for _, row in cohort_df.iterrows():
            sig = '***' if row['pval'] < 0.001 else ('**' if row['pval'] < 0.01 else ('*' if row['pval'] < 0.05 else ''))
            cca_flag = 'Yes' if row['has_cca'] else ''
            ctry_str = row['countries'][:40] + ('...' if len(row['countries']) > 40 else '')
            print(f"  {int(row['cohort_year']):>7d} {row['n_countries']:>7d} "
                  f"{row['did']:>8.3f} {row['se']:>8.3f} {row['pval']:>7.4f}{sig} "
                  f"{cca_flag:>5s}  {ctry_str}")

        # Aggregate ATT, weighted by the number of countries per cohort. The
        # SE treats cohort estimates as independent although they share one
        # control group.
        weights = cohort_df['n_countries'] / cohort_df['n_countries'].sum()
        agg_att = (cohort_df['did'] * weights).sum()
        agg_se = np.sqrt((cohort_df['se']**2 * weights**2).sum())

        print(f"\n  Aggregated ATT (cohort-weighted): {agg_att:.3f} (SE={agg_se:.3f})")

        # CCA vs non-CCA cohorts (unweighted means of the cohort DiDs).
        cca_cohorts = cohort_df[cohort_df['has_cca']]
        non_cca_cohorts = cohort_df[~cohort_df['has_cca']]

        if len(cca_cohorts) > 0:
            cca_att = cca_cohorts['did'].mean()
            print(f"  CCA cohorts mean ATT: {cca_att:.3f}")
        if len(non_cca_cohorts) > 0:
            non_cca_att = non_cca_cohorts['did'].mean()
            print(f"  Non-CCA cohorts mean ATT: {non_cca_att:.3f}")

        cohort_df.to_csv(OUTPUT_DIR / "did_cohort_atts.csv", index=False)

    return cohort_results


# =====================================================================
# MAIN
# =====================================================================

if __name__ == '__main__':
    # Script entry point: run Parts A-D in order and write each part's table
    # to OUTPUT_DIR (Part D writes its own CSV inside part_d_cohort_atts).
    print("=" * 70)
    print("PHASE 3: STAGGERED DIFFERENCE-IN-DIFFERENCES")
    print("=" * 70)

    df = load_panel()
    print(f"Loaded panel: {len(df)} obs, {df['iso3'].nunique()} countries")
    print(f"Openers: {df[df['status']=='opener']['iso3'].nunique()}")
    print(f"Never-opened: {df[df['status']=='never_opened']['iso3'].nunique()}")
    print(f"Always-open: {df[df['status']=='always_open']['iso3'].nunique()}")

    # Part A: TWFE baseline
    twfe_results, event_study_all, event_study_cca = part_a_twfe(df)

    # Save TWFE results
    if twfe_results:
        pd.DataFrame(twfe_results).to_csv(OUTPUT_DIR / "did_twfe_results.csv", index=False)

    # Save event study: stack whichever of the two runs produced output, so
    # the CCA table is still written if the all-openers run returned None.
    event_frames = [
        pd.DataFrame(es) for es in (event_study_all, event_study_cca) if es is not None
    ]
    if event_frames:
        es_combined = pd.concat(event_frames, ignore_index=True)
        es_combined.to_csv(OUTPUT_DIR / "did_event_study.csv", index=False)

    # Part B: Imputation estimator
    imputation_results = part_b_robust(df)
    if imputation_results:
        pd.DataFrame(imputation_results).to_csv(OUTPUT_DIR / "did_imputation.csv", index=False)

    # Part C: Extended triple-difference
    triple_results = part_c_triple_diff_extended(df)
    if triple_results:
        pd.DataFrame(triple_results).to_csv(OUTPUT_DIR / "did_triple_diff.csv", index=False)

    # Part D: Cohort-specific ATTs
    cohort_results = part_d_cohort_atts(df)

    print("\n" + "=" * 70)
    print("PHASE 3 COMPLETE")
    print("=" * 70)
    print(f"\nOutput files:")
    for f in sorted(OUTPUT_DIR.glob("did_*.csv")):
        print(f"  {f.name}")
